{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deploy comunication compression scheme in FedLab\n",
    "\n",
    "This tutorial provides comprehensive examples about implementing a communication efficiency scheme in FedLab. \n",
    "\n",
    "We take the baseline gradient compression algorithms as examples (top-k for gradient sparsification and QSGD for gradient quantization)."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compress example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "from fedlab.contrib.compressor.quantization import QSGDCompressor\n",
    "from fedlab.contrib.compressor.topk import TopkCompressor\n",
    "import torch\n",
    "\n",
    "tpk_compressor = TopkCompressor(compress_ratio=0.05) # top 5% gradient\n",
    "qsgd_compressor = QSGDCompressor(n_bit=8)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "To be compressed tensor: tensor([-0.0117, -0.2613,  0.9804, -1.1710,  0.8875, -0.3904, -0.2542, -1.9668,\n",
      "        -0.9684, -0.2456, -0.6323, -0.4404,  0.5047,  0.9951, -0.0873, -0.1383,\n",
      "         0.1960, -1.0819, -0.3392, -1.6281,  0.7332, -0.1525,  0.2921,  0.1089,\n",
      "        -1.9116,  0.0485,  0.8300,  1.2147,  0.6108,  0.4922,  0.1932,  0.7526,\n",
      "        -0.8702, -0.0484, -1.8132,  1.1468,  0.9768,  0.5359, -1.1338,  0.2096,\n",
      "        -0.8231, -1.3328, -0.8701, -0.4968,  2.0271,  0.8540, -0.2519,  0.9694,\n",
      "         0.6756,  0.0466, -1.5590,  0.1060,  1.0665, -0.0566, -1.7187, -0.3393,\n",
      "        -0.5582,  0.9904, -1.6938, -0.3719, -0.7633, -0.6630,  0.4305, -0.2383,\n",
      "        -0.3844, -0.8845, -0.6226, -0.8108, -0.6467, -0.8886,  0.6957,  0.1746,\n",
      "         0.7250,  0.0993, -0.4310,  0.0506,  0.1916,  0.3660,  0.7346,  0.0725,\n",
      "        -2.2989,  0.1391,  0.2980,  0.0408,  0.7876,  1.9719, -0.2313,  2.3823,\n",
      "         0.4999, -0.2233,  0.5530,  1.0528, -0.6531, -0.1192,  0.0406,  0.1687,\n",
      "         0.5074, -0.6520, -0.7053, -0.3894])\n",
      "Compressed results top-k values: tensor([ 2.3823, -2.2989,  2.0271,  1.9719, -1.9668])\n",
      "Compressed results top-k indices: tensor([87, 80, 44, 85,  7])\n",
      "Decompressed results: tensor([ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -1.9668,\n",
      "         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,\n",
      "         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,\n",
      "         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,\n",
      "         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,\n",
      "         0.0000,  0.0000,  0.0000,  0.0000,  2.0271,  0.0000,  0.0000,  0.0000,\n",
      "         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,\n",
      "         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,\n",
      "         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,\n",
      "         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,\n",
      "        -2.2989,  0.0000,  0.0000,  0.0000,  0.0000,  1.9719,  0.0000,  2.3823,\n",
      "         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,\n",
      "         0.0000,  0.0000,  0.0000,  0.0000])\n"
     ]
    }
   ],
   "source": [
    "# top-k\n",
    "tensor = torch.randn(size=(100,))\n",
    "shape = tensor.shape\n",
    "print(\"To be compressed tensor:\", tensor)\n",
    "\n",
    "# compress\n",
    "values, indices = tpk_compressor.compress(tensor)\n",
    "print(\"Compressed results top-k values:\",values)\n",
    "print(\"Compressed results top-k indices:\", indices)\n",
    "\n",
    "# decompress\n",
    "decompressed = tpk_compressor.decompress(values, indices, shape)\n",
    "print(\"Decompressed results:\", decompressed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "To be compressed tensor: tensor([-1.2208,  1.4187, -0.5420,  0.4409, -0.7793, -0.1413, -0.9905, -0.7609,\n",
      "         1.5335,  1.3902,  2.1191, -0.1703,  0.2943, -1.7366, -1.9368, -1.0637,\n",
      "        -1.8469,  0.3083,  1.4614, -1.0390, -0.3341, -0.1067, -0.0539,  0.2535,\n",
      "         2.2873,  0.3141,  0.5668, -1.0361, -0.2052,  0.6109, -0.7858,  0.1472,\n",
      "         0.3375, -2.0928, -0.7161,  0.0547, -0.5214,  0.5887, -0.0325,  0.4731,\n",
      "        -1.7919, -0.2484, -0.1335, -1.0788,  1.0368,  0.9461, -0.0935, -1.9321,\n",
      "        -0.9211, -0.9592,  2.4520,  0.6194, -0.6674,  0.0290,  0.0573,  0.3466,\n",
      "        -0.3336, -1.0915, -1.1358, -0.0730,  0.8704,  0.1320,  1.5872, -0.9954,\n",
      "        -0.0316, -0.0083,  0.2476,  0.4635,  0.7978,  0.9841, -1.4457,  0.6230,\n",
      "         0.2561,  0.1120,  0.5665,  0.9917, -0.7548,  0.6985,  0.1262,  0.9677,\n",
      "        -0.3080, -0.0502,  0.4727, -0.6012,  2.1698, -1.3320,  0.4133,  1.0892,\n",
      "        -0.5171, -2.3523, -1.3173,  1.2618,  0.2381, -0.4757,  2.0841,  0.0161,\n",
      "         1.7087,  0.5624, -1.1546, -0.4761])\n",
      "Compressed results QSGD norm: tensor([2.4520])\n",
      "Compressed results QSGD signs: tensor([False,  True, False,  True, False, False, False, False,  True,  True,\n",
      "         True, False,  True, False, False, False, False,  True,  True, False,\n",
      "        False, False, False,  True,  True,  True,  True, False, False,  True,\n",
      "        False,  True,  True, False, False,  True, False,  True, False,  True,\n",
      "        False, False, False, False,  True,  True, False, False, False, False,\n",
      "         True,  True, False,  True,  True,  True, False, False, False, False,\n",
      "         True,  True,  True, False, False, False,  True,  True,  True,  True,\n",
      "        False,  True,  True,  True,  True,  True, False,  True,  True,  True,\n",
      "        False, False,  True, False,  True, False,  True,  True, False, False,\n",
      "        False,  True,  True, False,  True,  True,  True,  True, False, False])\n",
      "Compressed results QSGD values: tensor([127, 148,  57,  46,  82,  15, 103,  80, 160, 146, 221,  18,  31, 181,\n",
      "        202, 112, 193,  32, 153, 109,  35,  11,   6,  26, 238,  33,  60, 108,\n",
      "         21,  64,  82,  16,  35, 219,  75,   6,  54,  61,   3,  49, 187,  26,\n",
      "         14, 112, 108,  99,   9, 202,  96, 100, 256,  65,  70,   3,   6,  36,\n",
      "         35, 114, 119,   7,  91,  13, 166, 104,   3,   1,  26,  49,  84, 103,\n",
      "        151,  65,  27,  11,  59, 104,  79,  73,  13, 101,  32,   5,  49,  63,\n",
      "        227, 139,  43, 114,  54, 245, 137, 131,  25,  50, 217,   2, 179,  59,\n",
      "        121,  49], dtype=torch.int32)\n"
     ]
    }
   ],
   "source": [
    "# qsgd\n",
    "tensor = torch.randn(size=(100,))\n",
    "shape = tensor.shape\n",
    "print(\"To be compressed tensor:\", tensor)\n",
    "\n",
    "# compress\n",
    "norm, signs, values = qsgd_compressor.compress(tensor)\n",
    "print(\"Compressed results QSGD norm:\", norm)\n",
    "print(\"Compressed results QSGD signs:\", signs)\n",
    "print(\"Compressed results QSGD values:\", values)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Decompressed results: tensor([-1.2164,  1.4175, -0.5459,  0.4406, -0.7854, -0.1437, -0.9865, -0.7662,\n",
      "         1.5325,  1.3984,  2.1167, -0.1724,  0.2969, -1.7336, -1.9347, -1.0727,\n",
      "        -1.8485,  0.3065,  1.4654, -1.0440, -0.3352, -0.1054, -0.0575,  0.2490,\n",
      "         2.2796,  0.3161,  0.5747, -1.0344, -0.2011,  0.6130, -0.7854,  0.1532,\n",
      "         0.3352, -2.0976, -0.7183,  0.0575, -0.5172,  0.5843, -0.0287,  0.4693,\n",
      "        -1.7911, -0.2490, -0.1341, -1.0727,  1.0344,  0.9482, -0.0862, -1.9347,\n",
      "        -0.9195, -0.9578,  2.4520,  0.6226, -0.6705,  0.0287,  0.0575,  0.3448,\n",
      "        -0.3352, -1.0919, -1.1398, -0.0670,  0.8716,  0.1245,  1.5899, -0.9961,\n",
      "        -0.0287, -0.0096,  0.2490,  0.4693,  0.8045,  0.9865, -1.4463,  0.6226,\n",
      "         0.2586,  0.1054,  0.5651,  0.9961, -0.7567,  0.6992,  0.1245,  0.9674,\n",
      "        -0.3065, -0.0479,  0.4693, -0.6034,  2.1742, -1.3313,  0.4119,  1.0919,\n",
      "        -0.5172, -2.3466, -1.3122,  1.2547,  0.2394, -0.4789,  2.0784,  0.0192,\n",
      "         1.7145,  0.5651, -1.1589, -0.4693])\n"
     ]
    }
   ],
   "source": [
    "# decompress\n",
    "decompressed = qsgd_compressor.decompress([norm, signs, values])\n",
    "print(\"Decompressed results:\", decompressed)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Use compressor in federated learning\n",
    "\n",
    "For example on the client side, we could compress the tensors are to compressed and upload the compressed results to server. And server could decompress the tensors follows the compression agreements.\n",
    "\n",
    "In jupyter notebook, we take the standalone scenario as example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fedlab.contrib.algorithm.basic_client import SGDSerialClientTrainer, SGDClientTrainer\n",
    "from fedlab.contrib.algorithm.basic_server import SyncServerHandler\n",
    "\n",
    "class CompressSerialClientTrainer(SGDSerialClientTrainer):\n",
    "    def setup_compressor(self, compressor):\n",
    "        #self.compressor = TopkCompressor(compress_ratio=k)\n",
    "        self.compressor = compressor\n",
    "\n",
    "    @property\n",
    "    def uplink_package(self):\n",
    "        package = super().uplink_package\n",
    "        new_package = []\n",
    "        for content in package:\n",
    "            pack = [self.compressor.compress(content[0])]\n",
    "            new_package.append(pack)\n",
    "        return new_package\n",
    "\n",
    "class CompressServerHandler(SyncServerHandler):\n",
    "    def setup_compressor(self, compressor, type):\n",
    "        #self.compressor = TopkCompressor(compress_ratio=k)\n",
    "        self.compressor = compressor\n",
    "        self.type = type\n",
    "\n",
    "    def load(self, payload) -> bool:\n",
    "        if self.type == \"topk\":\n",
    "            values, indices = payload[0]\n",
    "            decompressed_payload = self.compressor.decompress(values, indices, self.model_parameters.shape)\n",
    "\n",
    "        if self.type == \"qsgd\":\n",
    "            n, s, l = payload[0]\n",
    "            decompressed_payload = self.compressor.decompress((n,s,l))\n",
    "        \n",
    "        return super().load([decompressed_payload])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# main, this part we follow the pipeline in pipeline_tutorial.ipynb\n",
    "# But replace the hander and trainer by the above defined for communication compression\n",
    "\n",
    "# configuration\n",
    "import os\n",
    "from opcode import cmp_op\n",
    "from munch import Munch\n",
    "from fedlab.models.mlp import MLP\n",
    "\n",
    "model = MLP(784, 10)\n",
    "args = Munch\n",
    "\n",
    "args.total_client = 100\n",
    "args.alpha = 0.5\n",
    "args.seed = 42\n",
    "args.preprocess = False if os.path.exists(\"../datasets/mnist/fedmnist/train/data2.pkl\") else True\n",
    "args.cuda = True if torch.cuda.is_available() else False\n",
    "args.cmp_op = \"qsgd\" # \"topk, qsgd\"\n",
    "\n",
    "args.k = 0.1 # topk\n",
    "args.bit = 8 # qsgd\n",
    "\n",
    "if args.cmp_op == \"topk\":\n",
    "    compressor = TopkCompressor(args.k)\n",
    "\n",
    "if args.cmp_op == \"qsgd\":\n",
    "    compressor = QSGDCompressor(args.bit)\n",
    "\n",
    "from torchvision import transforms\n",
    "from fedlab.contrib.dataset.partitioned_mnist import PartitionedMNIST\n",
    "\n",
    "fed_mnist = PartitionedMNIST(root=\"../datasets/mnist/\",\n",
    "                             path=\"../datasets/mnist/fedmnist/\",\n",
    "                             num_clients=args.total_client,\n",
    "                             partition=\"noniid-labeldir\",\n",
    "                             dir_alpha=args.alpha,\n",
    "                             seed=args.seed,\n",
    "                             preprocess=args.preprocess,\n",
    "                             download=True,\n",
    "                             verbose=True,\n",
    "                             transform=transforms.Compose([\n",
    "                                 transforms.ToPILImage(),\n",
    "                                 transforms.ToTensor()\n",
    "                             ]))\n",
    "\n",
    "dataset = fed_mnist.get_dataset(0)  # get the 0-th client's dataset\n",
    "dataloader = fed_mnist.get_dataloader(\n",
    "    0,\n",
    "    batch_size=128)  # get the 0-th client's dataset loader with batch size 128\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# client\n",
    "from fedlab.contrib.algorithm.basic_client import SGDSerialClientTrainer, SGDClientTrainer\n",
    "\n",
    "# local train configuration\n",
    "args.epochs = 5\n",
    "args.batch_size = 128\n",
    "args.lr = 0.1\n",
    "\n",
    "trainer = CompressSerialClientTrainer(model, args.total_client,\n",
    "                                 cuda=args.cuda)  # serial trainer\n",
    "# trainer = SGDClientTrainer(model, cuda=True) # single trainer\n",
    "\n",
    "trainer.setup_dataset(fed_mnist)\n",
    "trainer.setup_optim(args.epochs, args.batch_size, args.lr)\n",
    "trainer.setup_compressor(compressor)\n",
    "\n",
    "# server\n",
    "from fedlab.contrib.algorithm.basic_server import SyncServerHandler\n",
    "\n",
    "# global configuration\n",
    "args.com_round = 10\n",
    "args.sample_ratio = 0.1\n",
    "\n",
    "handler = CompressServerHandler(model=model,\n",
    "                            global_round=args.com_round,\n",
    "                            num_clients=args.total_client,\n",
    "                            sample_ratio=args.sample_ratio,\n",
    "                            cuda=args.cuda)\n",
    "handler.setup_compressor(compressor, args.cmp_op)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      ">>> Local training: 100%|██████████| 10/10 [00:04<00:00,  2.28it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss 21.6554, test accuracy 0.2149\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      ">>> Local training: 100%|██████████| 10/10 [00:03<00:00,  2.64it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss 17.1997, test accuracy 0.5176\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      ">>> Local training: 100%|██████████| 10/10 [00:03<00:00,  2.73it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss 11.8354, test accuracy 0.6876\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      ">>> Local training: 100%|██████████| 10/10 [00:03<00:00,  2.94it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss 9.7127, test accuracy 0.6818\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      ">>> Local training: 100%|██████████| 10/10 [00:03<00:00,  2.66it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss 9.9816, test accuracy 0.6350\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      ">>> Local training: 100%|██████████| 10/10 [00:03<00:00,  2.69it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss 7.2361, test accuracy 0.7495\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      ">>> Local training: 100%|██████████| 10/10 [00:04<00:00,  2.40it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss 6.7841, test accuracy 0.7664\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      ">>> Local training: 100%|██████████| 10/10 [00:03<00:00,  2.70it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss 6.7317, test accuracy 0.7627\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      ">>> Local training: 100%|██████████| 10/10 [00:03<00:00,  2.60it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss 4.8795, test accuracy 0.8402\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      ">>> Local training: 100%|██████████| 10/10 [00:04<00:00,  2.40it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss 4.5749, test accuracy 0.8712\n"
     ]
    }
   ],
   "source": [
    "from fedlab.utils.functional import evaluate\n",
    "from fedlab.core.standalone import StandalonePipeline\n",
    "\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader\n",
    "import torchvision\n",
    "\n",
    "class EvalPipeline(StandalonePipeline):\n",
    "    def __init__(self, handler, trainer, test_loader):\n",
    "        super().__init__(handler, trainer)\n",
    "        self.test_loader = test_loader\n",
    "\n",
    "    def main(self):\n",
    "        while self.handler.if_stop is False:\n",
    "            # server side\n",
    "            sampled_clients = self.handler.sample_clients()\n",
    "            broadcast = self.handler.downlink_package\n",
    "\n",
    "            # client side\n",
    "            self.trainer.local_process(broadcast, sampled_clients)\n",
    "            uploads = self.trainer.uplink_package\n",
    "\n",
    "            # server side\n",
    "            for pack in uploads:\n",
    "                self.handler.load(pack)\n",
    "\n",
    "            loss, acc = evaluate(self.handler.model, nn.CrossEntropyLoss(),\n",
    "                                 self.test_loader)\n",
    "            print(f\"Centralized Evaluation round {self.handler.round}: loss {loss:.4f}, test accuracy {acc:.4f}\")\n",
    "\n",
    "\n",
    "test_data = torchvision.datasets.MNIST(root=\"../datasets/mnist/\",\n",
    "                                       train=False,\n",
    "                                       download=True,\n",
    "                                       transform=transforms.ToTensor())\n",
    "test_loader = DataLoader(test_data, batch_size=1024)\n",
    "\n",
    "standalone_eval = EvalPipeline(handler=handler,\n",
    "                               trainer=trainer,\n",
    "                               test_loader=test_loader)\n",
    "standalone_eval.main()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.0 ('fedlab')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "019ae50596e3d4df627f3288be8543f4b17347150bdb9d2aa2e7c637014aee00"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
